package MipsCPU(sysMipsCPU, MROM, MRAM) where
import RegFile
import MipsFIFO
import List
import ActionSeq

import MipsInstr
import MipsDefs

-- The register file: indexed by CPUReg, holding Value words.
type RegisterFile = RegFile CPUReg Value

-- The ROM (instruction memory): read with an IAddress, yields an IValue.
type MROM = ROM IAddress IValue

-- The RAM (data memory): addressed by Address, holding Value words.
type MRAM = RAM Address Value

--------

-- True exactly for the opcodes that branch on a comparison of two
-- registers (BEQ and BNE); False for every other opcode.
isCmpRegsOp :: Op -> Bool
isCmpRegsOp op =
    case op of
        BEQ -> True
        BNE -> True
        _   -> False

-- True exactly for the immediate-format opcodes that branch on comparing
-- one register against zero (BLEZ and BGTZ).
isCmpZeroImmOp :: Op -> Bool
isCmpZeroImmOp op =
    case op of
        BLEZ -> True
        BGTZ -> True
        _    -> False

-- True only for the store opcode (SW).
isStoreOp :: Op -> Bool
isStoreOp op =
    case op of
        SW -> True
        _  -> False

-- True only for the load opcode (LW).
isLoadOp :: Op -> Bool
isLoadOp op =
    case op of
        LW -> True
        _  -> False

-- True exactly for the register-format function codes that jump
-- (JR and JALR).
isJumpFunct :: Funct -> Bool
isJumpFunct fn =
    case fn of
        JR   -> True
        JALR -> True
        _    -> False

-- True exactly for the shifts whose amount comes from the instruction's
-- `sa` field (SLL, SRL, SRA); the variable shifts are not included.
isShiftOp :: Funct -> Bool
isShiftOp fn =
    case fn of
        SLL -> True
        SRL -> True
        SRA -> True
        _   -> False

-- True exactly for the REGIMM branches that compare against zero
-- without linking (BLTZ and BGEZ).
isCmpZeroRegImmOp :: REGIMM -> Bool
isCmpZeroRegImmOp ri =
    case ri of
        BLTZ -> True
        BGEZ -> True
        _    -> False

-- True exactly for the REGIMM branches that compare against zero and
-- link (BLTZAL and BGEZAL).
isCmpZALOp :: REGIMM -> Bool
isCmpZALOp ri =
    case ri of
        BLTZAL -> True
        BGEZAL -> True
        _      -> False


--------


-- Undecoded instructions: what the fetch stage hands to decode.
--   pc    - the instruction address the word was read from
--   instr - the raw instruction word read from instruction memory
struct UInstr =
        pc :: IAddress
        instr :: IValue
    deriving (Bits)

-- Undecoded instruction buffer (the FIFO between fetch and decode).
type UBuffer = FIFO UInstr

-- ALU operations.
-- Selects which computation `computeOp` performs on its two operands:
-- add, and, subtract, nor, or, xor, signed/unsigned set-less-than, and
-- the three shifts (left logical, right arithmetic, right logical).
data AOp
        = Aadd | Aand  | Asub | Anor | Aor  | Axor
        | Aslt | AsltU | Asll | Asra | Asrl
        deriving (Bits, Eq)

-- Maps an immediate-format arithmetic/logic opcode to the ALU operation
-- that implements it.  Only the opcodes listed are handled; the function
-- is deliberately left undefined for any other opcode.
opToAOp :: Op -> AOp
opToAOp op =
    case op of
        ADDI  -> Aadd
        ADDIU -> Aadd
        SLTI  -> Aslt
        SLTIU -> AsltU
        ANDI  -> Aand
        ORI   -> Aor
        XORI  -> Axor

-- Maps a register-format function code to the ALU operation that
-- implements it.  The fixed-amount and variable-amount shifts share the
-- same ALU operation.  Function codes not listed (jumps, BREAK, ...) are
-- deliberately left unhandled.
functToAOp :: Funct -> AOp
functToAOp fn =
    case fn of
        ADD  -> Aadd
        ADDU -> Aadd
        SUB  -> Asub
        SUBU -> Asub
        AND  -> Aand
        OR   -> Aor
        XOR  -> Axor
        NOR  -> Anor
        SLT  -> Aslt
        SLTU -> AsltU
        SRAV -> Asra
        SRLV -> Asrl
        SLLV -> Asll
        SRA  -> Asra
        SRL  -> Asrl
        SLL  -> Asll

------

-- Branch condition for the two-register compare branches: BEQ is taken
-- when the operands are equal, BNE when they differ.  Other opcodes are
-- not handled.
compareROp :: Op -> Value -> Value -> Bool
compareROp op a b =
    case op of
        BEQ -> a == b
        BNE -> a /= b

-- Branch condition for BLEZ / BGTZ: compares the operand against zero
-- using the vLE / vGT helpers.  Other opcodes are not handled.
compareZImmOp :: Op -> Value -> Bool
compareZImmOp op x =
    case op of
        BLEZ -> x `vLE` 0
        BGTZ -> x `vGT` 0

-- Branch condition for BLTZ / BGEZ: compares the operand against zero
-- using the vLT / vGE helpers.  Other REGIMM codes are not handled.
compareZRegImmOp :: REGIMM -> Value -> Bool
compareZRegImmOp ri x =
    case ri of
        BLTZ -> x `vLT` 0
        BGEZ -> x `vGE` 0

-- Branch condition for the linking branches BLTZAL / BGEZAL: compares
-- the operand against zero using the vLT / vGE helpers.  Other REGIMM
-- codes are not handled.
compareZALOp :: REGIMM -> Value -> Bool
compareZALOp ri x =
    case ri of
        BLTZAL -> x `vLT` 0
        BGEZAL -> x `vGE` 0


------

-- Decoded instruction, i.e., exec stage operation
--   Oload  - load into `dest` from address `addr + offs`
--   Ostore - store `datum` to address `addr + offs`
--   Oarith - apply ALU operation `aop` to `v1` and `v2`, result goes to `dest`
--   Owback - value already known; just carry `datum` along to `dest`
-- (The sums and the ALU computation are performed by `execute`.)
data DInstr
        = Oload  {              dest  :: CPUReg; addr :: Value; offs :: Value }
        | Ostore {              datum :: Value;  addr :: Value; offs :: Value }
        | Oarith { aop :: AOp;  dest  :: CPUReg; v1   :: Value; v2   :: Value }
        | Owback {                             dest :: CPUReg; datum :: Value }
    deriving (Bits)

-- Buffer between the decode and execute stages.
type DBuffer = FoldFIFO DInstr

-- Memory stage operation
--   Mload  - read data memory at `addr`, result goes to `dest`
--   Mstore - write `datum` to data memory at `addr`
--   Mwback - no memory access; carry `datum` along to `dest`
data MInstr
        = Mload  { dest :: CPUReg; addr  :: Value; }
        | Mstore { datum :: Value; addr  :: Value; }
        | Mwback { dest :: CPUReg; datum :: Value; }
    deriving (Bits)

-- Buffer between the execute and memory stages.
type MBuffer = FoldFIFO MInstr

-- Writeback stage operation
--   Wback - write `datum` into register `dest`
--   Wnone - nothing to write (produced for stores)
data WInstr
        = Wback { dest :: CPUReg; datum :: Value; }
        | Wnone
    deriving (Bits)

-- Buffer between the memory and writeback stages.
type WBuffer = FoldFIFO WInstr

-- Top-level CPU module.
-- Takes the instruction memory (`imem`) and data memory (`dmem`) and
-- returns an ActionSeq interface: `start` launches execution from
-- address 0, `done` reports that the CPU is not running.
sysMipsCPU :: MROM -> MRAM -> Module ActionSeq
sysMipsCPU imem dmem =
    module
        -- Program counter (no reset value; set by `start`).
        pc :: Reg IAddress
        pc <- mkRegU

        -- Register file covering Reg1..Reg31; Reg0 is special-cased by
        -- the decode stage (reads as 0) and the writeback stage (never written).
        rf :: RegisterFile
        rf <- mkRegFile Reg1 Reg31

        -- Fetch -> decode buffer.
        bd :: UBuffer
        bd <- mkFIFO

        -- Decode -> execute buffer.
        be :: DBuffer
        be <- mkFoldFIFO

        -- Execute -> memory buffer.
        bm :: MBuffer
        bm <- mkFoldFIFO

        -- Memory -> writeback buffer.
        bw :: WBuffer
        bw <- mkFoldFIFO

        -- True while the CPU is fetching; cleared by the BREAK decode rule.
        running :: Reg Bool
        running <- mkReg False

        interface
            -- Only callable while stopped: begin running at address 0.
            start = action { running := True; pc := 0 } when not running
            done = not running

        -- One rule set per pipeline stage.
        addRules $
            mkFetchRules running pc bd imem               <+>
            mkDecodeRules running pc rf bd be bm bw dmem  <+>
            mkExecuteRules be bm                          <+>
            mkMemoryRules bm bw dmem                      <+>
            mkWritebackRules rf bw


--------------------------------------------------------------------------------

-- INSTRUCTION FETCH STAGE --
-- While `run` holds, read the instruction word at the current pc, queue
-- it (tagged with that pc) for the decode stage, and advance pc by one.
-- Both actions belong to the same rule and use the current pc value.
mkFetchRules :: Bool -> Reg IAddress -> UBuffer -> MROM -> Rules
mkFetchRules run pc bd imem =
    rules
      "Fetch":
        when run
          ==> action {
                bd.enq (UInstr { pc = pc; instr = imem.read pc });
                pc := pc + 1;
              }

--------------------------------------------------------------------------------

-- Possible write back, just used when computing bypasses.
-- Summarises what a pipeline entry will do to the register file so that
-- decode can either forward the value (WBDo) or stall (WBWait).
data WBack
        = WBNone                -- no write-back
        | WBWait CPUReg                -- waiting for value
        | WBDo CPUReg Value        -- ready to write
    deriving (Bits)

-- Stage operations that can be summarised as a pending write-back.
class HasWBack a where
    getWBack :: a -> WBack
-- Execute-stage entries: only Owback already carries its value; loads and
-- arithmetic still have to produce theirs; everything else (stores)
-- writes no register.
instance HasWBack DInstr where
    getWBack (Owback { dest; datum }) = WBDo dest datum
    getWBack (Oload { dest }) = WBWait dest
    getWBack (Oarith { dest }) = WBWait dest
    getWBack _ = WBNone
-- Memory-stage entries: Mwback carries its value; a load is still
-- waiting on memory; everything else (stores) writes no register.
instance HasWBack MInstr where
    getWBack (Mwback { dest; datum }) = WBDo dest datum
    getWBack (Mload { dest }) = WBWait dest
    getWBack _ = WBNone
-- Writeback-stage entries: the value is always available if there is one.
instance HasWBack WInstr where
    getWBack (Wback { dest; datum }) = WBDo dest datum
    getWBack Wnone = WBNone

-- Summarise the pending write-back of a FoldFIFO: fold over the FIFO,
-- transforming each element with `g` and taking its write-back summary
-- while ignoring the accumulated value; WBNone is the fold's seed.
getWB :: (HasWBack b) => (a -> b) -> FoldFIFO a -> WBack
getWB g f =
    let summarise entry _ = getWBack (g entry)
    in  f.foldr summarise WBNone

-- Bypass selection for one bypass point: when the write-back is ready
-- (WBDo) and targets register `r`, yield its value; otherwise pass the
-- fallback value through unchanged.
fwd :: CPUReg -> WBack -> Value -> Value
fwd r wb fallback =
    case wb of
        WBDo r' v when r == r' -> v
        _                      -> fallback

-- A bypass point forces a stall on register `r` only when it is still
-- waiting (WBWait) for the value of that same register.
noStall :: CPUReg -> WBack -> Bool
noStall r wb =
    case wb of
        WBWait r' -> r /= r'
        _         -> True

-- INSTRUCTION DECODE STAGE --
-- Builds the decode rules.  Every rule pattern-matches the instruction at
-- the head of `bd`; exactly which rule applies is decided by the
-- instruction format and opcode/function code.  Source registers are read
-- through `regValue` (register file plus bypasses) and a rule only fires
-- when none of its source registers is still waiting for a value
-- (`stallFree`).  Rules that redirect `pc` clear `bd`, discarding whatever
-- has been fetched behind the current instruction.
mkDecodeRules :: Reg Bool -> Reg IAddress -> RegisterFile -> UBuffer -> DBuffer ->
                 MBuffer -> WBuffer -> MRAM -> Rules
mkDecodeRules running pc rf bd be bm bw dmem =
    let
        -- Bypass points
        bps :: List WBack
--        bps = getWB      id be :> getWB            id bm :> getWB id bw :> Nil    -- take bypasses from fifo outputs
        bps = getWB execute be :> getWB (memory dmem) bm :> getWB id bw :> Nil    -- take bypasses from fifo inputs

        -- XXX wrong for FIFOs bigger than 1
        stallFree :: CPUReg -> Bool
        stallFree r = all (noStall r) bps

        -- Current value of a register: Reg0 always reads as 0; otherwise the
        -- register-file value, overridden by any bypass point that has a
        -- ready value for `r` (earlier entries of `bps` take precedence).
        regValue :: CPUReg -> Value
        regValue Reg0 = 0
        regValue r = foldr (fwd r) (rf.sub r) bps

        -- Address of the instruction being decoded.
        pc' :: IAddress
        pc' = bd.first.pc

        -- The instruction being decoded, unpacked from its raw word.
        instr :: Instruction
        instr = unpack bd.first.instr

        -- Link operation: writes `iToV pc' + 8` into `rd`.
        retPC :: CPUReg -> DInstr
--        retPC rd = Oarith { dest = rd; aop = Aadd; v1 = iToV pc'; v2 = 8; }       -- add in ALU
        retPC rd = Owback { dest = rd; datum = iToV pc' + 8; }                    -- separate adder

        -- The bitwise-logical immediates (ANDI/ORI/XORI) take a zero-extended
        -- immediate in the MIPS ISA; every other immediate ALU op sign-extends.
        zeroExtImm :: Op -> Bool
        zeroExtImm ANDI = True
        zeroExtImm ORI  = True
        zeroExtImm XORI = True
        zeroExtImm _    = False

    in
        rules
          -- Load upper immediate: needs no register read, so never stalls.
          "D-LUI":
            when Immediate { op = LUI; rt; imm } <- instr
              ==> action {
                    bd.deq;
                    be.enq (Owback { dest = rt; datum = vZeroExt imm << 16 })
                }

          -- Remaining immediate-format ALU operations (everything that is not
          -- LUI, a branch, a load or a store).
          "D-Immediate":
            when Immediate { op; rs; rt; imm } <- instr, op /= LUI,
                 not (isCmpRegsOp op || isCmpZeroImmOp op || isStoreOp op || isLoadOp op),
                 stallFree rs
              ==> action {
                    bd.deq;
                    be.enq (Oarith {    dest = rt;
                                        aop = opToAOp op;
                                        v1 = regValue rs;
                                        v2 = if zeroExtImm op
                                                then vZeroExt imm
                                                else vSignExt imm; })
                }

          -- Store: rt supplies the data, rs the base address.
          "D-Store":
            when Immediate { op; rs; rt; imm } <- instr, isStoreOp op,
                 stallFree rs, stallFree rt
              ==> action {
                    bd.deq;
                    be.enq (Ostore { datum = regValue rt; addr = regValue rs; offs = vSignExt imm; })
                }

          -- Load: rs supplies the base address, rt is the destination.
          "D-Load":
            when Immediate { op; rs; rt; imm } <- instr, isLoadOp op,
                 stallFree rs
              ==> action {
                    bd.deq;
                    be.enq (Oload { dest = rt; addr = regValue rs; offs = vSignExt imm; })
                }

          -- BREAK stops the CPU.
          -- NOTE(review): the BREAK instruction is not dequeued, so it stays at
          -- the head of bd after the CPU stops -- confirm this is intended.
          "D-Break":
            when Register { funct = BREAK } <- instr
             ==> action { running := False }

          -- Register-format ALU operations.  For the fixed-amount shifts the
          -- second operand is the `sa` field, so rt is not read and cannot stall.
          -- NOTE(review): the shifted operand is taken from rs here; standard
          -- MIPS shifts take it from rt -- confirm against MipsInstr's layout.
          "D-Register":
            when Register { rs; rt; rd; sa; funct } <- instr, not (isJumpFunct funct), funct /= BREAK,
                 stallFree rs, isShiftOp funct || stallFree rt
              ==> action {
                    bd.deq;
                    be.enq (Oarith {    dest = rd;
                                        aop = functToAOp funct;
                                        v1 = regValue rs;
                                        v2 = if isShiftOp funct
                                                then vZeroExt sa
                                                else regValue rt; })
                }

          -- Jump to the address held in rs.
          "D-JumpRegister":
            when Register { rs; funct = JR } <- instr,
                 stallFree rs
              ==> action {
                    bd.clear;
                    pc := vToI (regValue rs);
                  }

          -- Jump to the address held in rs and write the link value to rd.
          "D-JumpAndLinkRegister":
            when Register { rs; rd; funct = JALR } <- instr,
                 stallFree rs
              ==> action {
                    bd.clear;
                    pc := vToI (regValue rs);
                    be.enq (retPC rd)
                }

          -- Absolute jump; the target is formed by jumpExtend from pc' and the
          -- instruction's target field.
          "D-Jump":
            when Jump { op = J; target } <- instr
              ==> action {
                    bd.clear;
                    pc := jumpExtend pc' target;
                  }

          -- Absolute jump that also writes the link value to Reg31.
          "D-JumpAndLink":
            when Jump { op = JAL; target } <- instr
              ==> action {
                    bd.clear;
                    pc := jumpExtend pc' target;
                    be.enq (retPC Reg31)
                }

          -- Conditional branches: when taken, redirect pc relative to the
          -- following instruction and flush bd; when not taken, just drop the
          -- branch from bd.
          "D-BranchCmpRegs":
            when Immediate { op; rs; rt; imm } <- instr, isCmpRegsOp op,
                 stallFree rs, stallFree rt
              ==> if compareROp op (regValue rs) (regValue rt) then action {
                     pc := pc' + 1 + iaSignExt imm;
                     bd.clear;
                  } else
                     bd.deq

          "D-BranchCmpZeroImm":
            when Immediate { op; rs; imm } <- instr, isCmpZeroImmOp op,
                 stallFree rs
              ==> if compareZImmOp op (regValue rs) then action {
                     pc := pc' + 1 + iaSignExt imm;
                     bd.clear;
                  } else
                     bd.deq

          "D-BranchCmpZeroRegImm":
            when RegImm { op; rs; imm } <- instr, isCmpZeroRegImmOp op,
                 stallFree rs
              ==> if compareZRegImmOp op (regValue rs) then action {
                     pc := pc' + 1 + iaSignExt imm;
                     bd.clear;
                  } else
                     bd.deq

          -- As above, but a taken branch also writes the link value to Reg31.
          "D-BranchCmpZeroAndLink":
            when RegImm { op; rs; imm } <- instr, isCmpZALOp op,
                 stallFree rs
              ==> if compareZALOp op (regValue rs) then action {
                     pc := pc' + 1 + iaSignExt imm;
                     bd.clear;
                     be.enq (retPC Reg31)
                  } else
                     bd.deq

--------------------------------------------------------------------------------

-- ALU STAGE --
-- Single unconditional rule: take the operation at the head of `be`,
-- run it through `execute`, and queue the result for the memory stage.
mkExecuteRules :: DBuffer -> MBuffer -> Rules
mkExecuteRules be bm =
        rules
          "Execute":
            when True
              ==> action {
                    be.deq;
                    bm.enq (execute be.first);
                  }

-- The ALU proper: applies the selected operation to two operands.
-- Aslt compares with the vLT helper while AsltU uses the plain `<`;
-- the shifts take their amount from the second operand via vToNat.
computeOp :: AOp -> Value -> Value -> Value
computeOp aop a b =
    case aop of
        Aadd  -> a + b
        Aand  -> a & b
        Asub  -> a - b
        Anor  -> invert (a | b)
        Aor   -> a | b
        Axor  -> a ^ b
        Aslt  -> if a `vLT` b then 1 else 0
        AsltU -> if a   <   b then 1 else 0
        Asll  -> a   <<   vToNat b
        Asrl  -> a   >>   vToNat b
        Asra  -> a `vSRA` vToNat b

-- Turn an execute-stage operation into a memory-stage operation:
-- loads and stores get their effective address (addr + offs),
-- arithmetic operations are evaluated by the ALU, and ready-made
-- write-backs are forwarded untouched.
execute :: DInstr -> MInstr
execute di =
    case di of
        Oload  { dest;  addr;  offs } -> Mload  { dest  = dest;  addr = addr + offs }
        Ostore { datum; addr;  offs } -> Mstore { datum = datum; addr = addr + offs }
        Oarith { dest; aop; v1; v2; } -> Mwback { dest = dest; datum = computeOp aop v1 v2 }
        Owback { dest;        datum } -> Mwback { dest = dest; datum = datum }

--------------------------------------------------------------------------------

-- MEMORY STAGE --
-- Single unconditional rule: remove the operation at the head of `bm`,
-- queue the write-back that `memory` derives from it, and, if it is a
-- store, write the datum to data memory.
mkMemoryRules :: MBuffer -> WBuffer -> MRAM -> Rules
mkMemoryRules bm bw dmem =
        rules
          "Memory":
            when True
              ==> action {
                    bm.deq;
                    bw.enq (memory dmem bm.first);
                    case bm.first of {
                        Mstore { datum; addr } -> dmem.write (vToA addr) datum;
                        _ -> action { };
                    }
                  }

-- Derive the register write-back for a memory-stage operation: a load
-- reads data memory at its address, a ready-made write-back is passed
-- through, and a store produces no write-back at all.
memory :: MRAM -> MInstr -> WInstr
memory dmem mi =
    case mi of
        Mload  { dest;  addr } -> Wback { dest = dest; datum = dmem.read (vToA addr) }
        Mwback { dest; datum } -> Wback { dest = dest; datum = datum }
        Mstore {             } -> Wnone

--------------------------------------------------------------------------------

-- WRITEBACK STAGE --
-- Single unconditional rule: remove the operation at the head of `bw`
-- and, when it carries a value for a register other than Reg0, update
-- the register file.  Everything else is dropped without effect.
mkWritebackRules :: RegisterFile -> WBuffer -> Rules
mkWritebackRules rf bw =
    rules
      "Wback":
        when True
          ==> action {
                bw.deq;
                case bw.first of {
                    Wnone -> action { };
                    Wback { dest; datum } when dest /= Reg0 -> rf.upd dest datum;
                    _ -> action { }
                }
              }
